import numpy as np

from scipy.special import comb, digamma, gammaln, psi
from scipy.stats import gamma
# import sparse
import scipy

#from gensim.matutils import dirichlet_expectation, logsumexp

# Import TensorLy
import tensorly as tl
from tensorly import norm

# Select NumPy as the TensorLy backend for this module.
tl.set_backend('numpy')
# NOTE(review): `device` is not referenced anywhere in the visible part of this
# file; the trailing comment suggests 'cuda' is the alternative value -- confirm
# against callers before changing or removing it.
device = 'cpu'#cuda

def get_ei(length, i):
    '''Return the i-th standard basis vector of the requested length.

    The vector is allocated with the active TensorLy backend (`tl.zeros`),
    so its concrete type follows whatever backend is currently selected.
    '''
    basis_vec = tl.zeros(length)
    basis_vec[i] = 1  # the single non-zero coordinate
    return basis_vec

def dirichlet_expectation(alpha):
    '''Return E[log(theta)] for theta ~ Dirichlet(alpha).

    Computed as psi(alpha) - psi(sum(alpha)), where psi is the digamma
    function.

    alpha: array of Dirichlet parameters. A 1-D array is treated as a single
        parameter vector; otherwise the normalising sum is taken along
        axis 1, so each row of a 2-D array is handled independently.

    Returns an array with the same shape as alpha.
    '''
    if len(alpha.shape) != 1:
        # keepdims leaves a length-1 axis so the totals broadcast back
        # against alpha along axis 1.
        totals = np.sum(alpha, axis=1, keepdims=True)
        return psi(alpha) - psi(totals)
    return psi(alpha) - psi(np.sum(alpha))

def smooth_beta(beta, smoothing = 0.01):
    '''Mix beta with a uniform distribution so that no entry is exactly zero.

    Every entry becomes (1 - smoothing) * beta + smoothing / n_rows, where
    n_rows is beta.shape[0].

    beta: 2-D array; the first assertion requires every column of the
        smoothed result to sum to 1 (within 1e-6).
    smoothing: weight given to the uniform component.

    Returns the smoothed array; beta itself is not modified.

    Raises AssertionError if a smoothed column does not sum to 1, or if
    smoothing > 1e-4 and some smoothed entry is not above 1e-10.
    '''
    n_rows, n_cols = beta.shape[0], beta.shape[1]
    smoothed_beta = (1 - smoothing) * beta
    # In-place add keeps the dtype produced by the scaling step above.
    smoothed_beta += np.full((n_rows, n_cols), smoothing / n_rows, dtype=float)

    column_totals = np.sum(smoothed_beta, axis=0)
    assert np.all(abs(column_totals - 1) <= 1e-6), 'sum not close to 1'
    assert smoothing <= 1e-4 or np.all(smoothed_beta > 1e-10), 'zero values'
    return smoothed_beta

def simplex_proj(V):
    '''Project the vector V onto the probability simplex.

    The entries are sorted in descending order; for each prefix length k the
    candidate shift (sum of the k largest entries - 1) / k is computed, and
    the shift for the largest k whose k-th largest entry still exceeds its
    candidate is selected. That shift is subtracted from V and negative
    results are clipped to zero.

    V: 1-D numeric array (not modified).

    Returns a tuple (projection, theta) where theta is the shift that was
    subtracted.
    '''
    n = V.size
    descending = np.sort(V)[::-1]
    prefix_excess = np.cumsum(descending, dtype=float) - 1
    inv_counts = np.reciprocal(np.arange(1, n + 1, dtype=float))
    shifts = prefix_excess * inv_counts
    margins = descending - shifts

    # Last position with a positive margin; falls back to 0 when none is.
    pivot = next((k for k in reversed(range(n)) if margins[k] > 0), 0)

    theta = shifts[pivot]
    projected = V - theta
    projected[projected < 0.0] = 0.0
    return (projected, theta)

def perplexity(documents, beta, alpha, gamma, subsample_ratio = 1.0, sp = False):
    '''Return the scaled variational log-likelihood bound of a corpus.

    Despite the name, the value returned is `log_likelihood * subsample_ratio`;
    it is neither negated nor divided by the corpus word count.

    documents: iterable of per-document word-count vectors.
    beta: indexed by word id on its first axis and added directly to
        E[log theta] inside `logsumexp`.
        NOTE(review): this suggests a (vocab, topics) array of log-domain
        topic/word weights -- confirm against callers.
    alpha: Dirichlet prior on the document/topic distribution.
    gamma: per-document variational Dirichlet parameters; `gamma[i]` belongs
        to the i-th document. Inside this function the name shadows the
        module-level `scipy.stats.gamma` import.
    subsample_ratio: factor the accumulated bound is multiplied by.
    sp: when true, each document is treated as a sparse array and its
        non-zero word ids are read from its own `nonzero()` method; otherwise
        `np.nonzero` is used.

    Returns the accumulated bound multiplied by subsample_ratio.
    '''
    log_likelihood = 0.0
    for i, doc in enumerate(documents):
        gammad = gamma[i]
        elogthetad = dirichlet_expectation(gammad)

        # The sparse path used to call `sparse.COO.nonzero(doc)`, but the
        # `sparse` module is never imported in this file, so it raised
        # NameError. Calling the method on the document itself is equivalent
        # and needs no module reference.
        word_ids = doc.nonzero()[0] if sp else np.nonzero(doc)[0]

        # Word term: count * log sum_k exp(E[log theta_k] + beta[word, k]).
        for idx in word_ids:
            log_likelihood += doc[idx] * logsumexp(elogthetad + beta[idx].T)

        # Dirichlet terms comparing the prior alpha with the variational gamma.
        log_likelihood += np.sum((alpha - gammad) * elogthetad)
        log_likelihood += np.sum(gammaln(gammad) - gammaln(alpha))
        log_likelihood += gammaln(np.sum(alpha)) - gammaln(np.sum(gammad))

    return log_likelihood*subsample_ratio

def doc_likelihood(X, theta, alpha, beta, subsample_ratio = 1.0):
    '''Return the scaled log-likelihood of the documents given theta and beta.

    X: document/word count matrix (one row per document).
    theta: document/topic matrix (one row per document).
    alpha: Dirichlet weights; (alpha - 1) multiplies log(theta).
    beta: topic/word matrix (topics on the first axis, words on the second).
    subsample_ratio: factor the total is multiplied by.

    The total is sum((alpha - 1) * log(theta)) plus, for every document, the
    sum over its non-zero words of count * log(theta_d . beta[:, word]).
    '''
    total = (np.log(theta) * (alpha - 1.)).sum()
    for d in range(X.shape[0]):
        # Only words that actually occur in the document contribute.
        word_ids = np.nonzero(X[d, :])[0]
        word_probs = np.dot(theta[d, :], beta[:, word_ids])
        total += (np.log(word_probs) * X[d, word_ids]).sum()
    return total*subsample_ratio

def logsumexp(x):
    '''Return log(sum(exp(x))) over all elements of x, computed stably.

    The maximum is subtracted before exponentiating so large inputs do not
    overflow. If the maximum is not finite (every entry is -inf, some entry
    is +inf, or a NaN is present) it is returned as-is; subtracting it would
    evaluate inf - inf and turn a well-defined -inf / +inf result into NaN.

    x: non-empty array of log-domain values.
    '''
    a = np.max(x)
    if not np.isfinite(a):
        return a
    return a + np.log(np.sum(np.exp(x - a)))
